Kruskal's Algorithm

  1. Sort all the edges in non-decreasing order of their weight.
  2. Pick the smallest remaining edge. Check whether it forms a cycle with the spanning tree built so far.
    If no cycle is formed, include this edge; otherwise, discard it.
  3. Repeat step 2 until there are $V-1$ edges in the spanning tree, where $V$ is the number of vertices.
Note: Kruskal's algorithm applies to undirected graphs only; the directed analogue (minimum spanning arborescence) needs a different algorithm, such as Chu–Liu/Edmonds.

Wrapper Class for Graph Visualization

In [1]:
import networkx as nx 
import matplotlib.pyplot as plt 

# Wrapper that builds a networkx graph and draws it with matplotlib.
class GraphVisualization:
    """Build a networkx graph from an edge list or adjacency matrix and draw it.

    Integer vertex indices are turned into letter labels with ``chr(i + 65)``
    (0 -> 'A', 1 -> 'B', ...).
    """

    def __init__(self, weighted, edge_list=None, adjancency_matrix=None, isDirected=False):
        """Populate ``self.G`` from *edge_list*, or from *adjancency_matrix* if no edges are given.

        Args:
            weighted: if true, ``visualize`` also draws each edge's weight.
            edge_list: sequence of ``(u, v, weight)`` tuples with integer
                vertices. A ``(u, v)`` pair without a weight gets weight 1.
            adjancency_matrix: matrix (list of lists); every entry > 0 at
                ``[i][j]`` becomes an edge ``i -> j`` with that weight. Only
                consulted when *edge_list* is missing or empty.
            isDirected: build an ``nx.DiGraph`` instead of an ``nx.Graph``.

        Raises:
            ValueError: if neither a non-empty edge list nor a non-empty
                adjacency matrix is supplied (subclass of ``Exception``, so
                existing ``except Exception`` handlers still catch it).
        """
        self.weighted = weighted
        self.G = nx.DiGraph() if isDirected else nx.Graph()

        # ``None`` defaults avoid the shared-mutable-default pitfall; ``len``
        # (rather than truthiness) keeps array-like inputs working.
        if edge_list is not None and len(edge_list) > 0:
            for edge in edge_list:
                weight = edge[2] if len(edge) > 2 else 1
                self.G.add_edge(chr(edge[0] + 65), chr(edge[1] + 65), weight=weight)
        elif adjancency_matrix is not None and len(adjancency_matrix) > 0:
            for i, row in enumerate(adjancency_matrix):
                for j, weight in enumerate(row):
                    # Non-positive entries are skipped: 0 means "no edge", and
                    # edge_list_to_adjacency_matrix(directed=True) writes
                    # negative values as reverse-direction markers.
                    if weight <= 0:
                        continue
                    self.G.add_edge(chr(i + 65), chr(j + 65), weight=weight)
        else:
            raise ValueError("I expect atleast an edge-list or an adjancency matrix")

    def visualize(self):
        """Draw ``self.G`` with matplotlib, labelling edge weights if ``self.weighted``."""
        pos = nx.spring_layout(self.G, scale=5000)

        # NOTE(review): multiplying every position by the same factor likely
        # does not change the rendered spacing, since matplotlib autoscales
        # the axes to the data; kept for parity with the original drawing code.
        for key in pos:
            pos[key] *= 10000
        nx.draw_networkx(self.G, pos, node_size=700, node_color='#00ccff', font_size=10)

        if self.weighted:
            # Edge labels show the 'weight' attribute set in __init__.
            labels = nx.get_edge_attributes(self.G, 'weight')
            nx.draw_networkx_edge_labels(self.G, pos, edge_labels=labels)
        plt.show()

Function to convert an edge list to an adjacency matrix

In [2]:
def edge_list_to_adjacency_matrix(edge_list, directed=False):
    """Convert an edge list into a dense adjacency matrix.

    Row/column ``k`` of the result corresponds to the k-th vertex label in
    sorted order (falling back to first-appearance order when the labels
    cannot be compared with each other), so the layout is deterministic.

    Args:
        edge_list: iterable of ``(u, v, weight)`` tuples; vertex labels must
            be hashable. A ``(u, v)`` pair without a weight gets weight 1.
        directed: when False, each edge is written in both directions with
            its weight. When True, ``matrix[u][v] = weight`` and the reverse
            cell ``matrix[v][u]`` receives ``-weight`` as a reverse-direction
            marker, unless that cell already holds a positive (real) edge.

    Returns:
        An ``n x n`` list of lists, ``n`` being the number of distinct
        vertices; 0 means "no edge".
    """
    # Materialize once: the edges are traversed twice below.
    edge_list = list(edge_list)

    # dict.fromkeys de-duplicates while preserving first-appearance order.
    labels = list(dict.fromkeys(vertex for edge in edge_list for vertex in edge[:2]))
    try:
        labels = sorted(labels)
    except TypeError:
        # Mixed, mutually unorderable labels: keep first-appearance order.
        pass
    position = {label: k for k, label in enumerate(labels)}

    n = len(labels)
    adj_matrix = [[0] * n for _ in range(n)]

    for edge in edge_list:
        u, v = position[edge[0]], position[edge[1]]
        weight = edge[2] if len(edge) > 2 else 1
        adj_matrix[u][v] = weight
        if not directed:
            adj_matrix[v][u] = weight
        elif adj_matrix[v][u] <= 0:
            # Only mark the reverse direction if no real v -> u edge exists,
            # so an explicit opposite edge is never overwritten.
            adj_matrix[v][u] = -weight

    return adj_matrix

Class to check for cycles (loops) in the graph

In [3]:
class helper:
    """Cycle detection for a graph stored as an adjacency matrix.

    An edge ``i -> j`` is followed when ``a_matrix[i][j] > 0``; zero and
    negative entries are ignored. A visited neighbour other than the vertex
    we came from is treated as a cycle, which is correct for undirected
    (symmetric) matrices.
    """

    def __init__(self):
        # Vertices visited by the most recent check_loop() call, in visit order.
        self.v_arr = []

    def check_loop(self, a_matrix: list, index: int = 0, prev_vertex=None) -> bool:
        """Return True if the graph in *a_matrix* contains a cycle.

        With the default ``prev_vertex=None`` the depth-first search starts
        at *index* and then continues from every still-unvisited vertex, so a
        cycle in any connected component (including a self-loop) is found.

        If *prev_vertex* is given, it is treated as the vertex *index* was
        reached from: the edge back to it is not counted as a cycle and only
        the component containing *index* is searched.

        Each call starts from a clean state, so one instance can be reused.
        The search is iterative, so long paths cannot hit the recursion limit.

        Args:
            a_matrix: square list of lists of edge weights.
            index: vertex to start the search from.
            prev_vertex: optional parent of *index* (see above).

        Returns:
            True if a cycle was found, otherwise False.
        """
        self.v_arr = []
        n = len(a_matrix)
        if n == 0:
            return False

        if prev_vertex is None:
            # Whole-graph check: start at `index`, then sweep remaining roots.
            roots = [index, *range(n)]
        else:
            roots = [index]

        visited = set()
        for root in roots:
            if root in visited:
                continue
            stack = [(root, prev_vertex)]
            while stack:
                vertex, parent = stack.pop()
                if vertex in visited:
                    # Pushed by two different neighbours: two paths -> cycle.
                    return True
                visited.add(vertex)
                self.v_arr.append(vertex)
                for nbr, weight in enumerate(a_matrix[vertex]):
                    if weight <= 0 or nbr == parent:
                        continue
                    if nbr in visited:
                        # Back edge to an already-visited, non-parent vertex.
                        return True
                    stack.append((nbr, vertex))
        return False

The example graph used throughout this notebook

In [6]:
# Undirected example graph as (u, v, weight) triples; vertex k is drawn as chr(k + 65).
edgeList = [
    (0, 1, 5),  # A - B
    (0, 2, 3),  # A - C
    (1, 2, 4),  # B - C
    (1, 3, 6),  # B - D
    (2, 3, 5),  # C - D
    (2, 5, 6),  # C - F
    (3, 5, 6),  # D - F
    (3, 4, 6),  # D - E
    (4, 5, 3),  # E - F
    (4, 6, 4),  # E - G
    (5, 6, 4),  # F - G
    (1, 4, 2),  # B - E
]
G = GraphVisualization(weighted=True, edge_list=edgeList)
G.visualize()
No description has been provided for this image

Adjacency matrix of the graph

In [7]:
edge_list_to_adjacency_matrix(edge_list=edgeList)  # undirected by default
Out[7]:
[[0, 5, 3, 0, 0, 0, 0],
 [5, 0, 4, 6, 2, 0, 0],
 [3, 4, 0, 5, 0, 6, 0],
 [0, 6, 5, 0, 6, 6, 0],
 [0, 2, 0, 6, 0, 3, 4],
 [0, 0, 6, 6, 3, 0, 4],
 [0, 0, 0, 0, 4, 4, 0]]

Sorting the edges by weight (cost)

In [8]:
# Kruskal consumes edges cheapest-first; sorted() is stable, so equal weights keep input order.
sorted_edges = sorted(edgeList, key=lambda e: e[2])
for edge in sorted_edges:
    print(f'{chr(edge[0] + 65)} -- {edge[2]} --> {chr(edge[1] + 65)}')
B -- 2 --> E
A -- 3 --> C
E -- 3 --> F
B -- 4 --> C
E -- 4 --> G
F -- 4 --> G
A -- 5 --> B
C -- 5 --> D
B -- 6 --> D
C -- 6 --> F
D -- 6 --> F
D -- 6 --> E

Kruskal's algorithm driver code

In [9]:
def kruskal_MST(edges):
    """Return the edges of a minimum spanning tree (forest) via Kruskal's algorithm.

    Uses a union-find (disjoint-set) structure to reject any edge whose
    endpoints are already connected. This correctly rejects self-loops,
    parallel edges, and cycles in any connected component.

    Args:
        edges: iterable of ``(u, v, weight)`` tuples describing an undirected
            graph; vertex labels may be any hashable values. Edges are
            stably sorted by weight here, so already-sorted input keeps its
            order and unsorted input is handled too.

    Returns:
        List of the accepted edge tuples in non-decreasing weight order:
        ``V - 1`` edges for a connected graph with ``V`` vertices, a spanning
        forest for a disconnected one, and ``[]`` for empty input.
    """
    parent = {}

    def find(x):
        # Root of x's set, halving the path on the way up.
        parent.setdefault(x, x)
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    mst = []
    for edge in sorted(edges, key=lambda e: e[2]):
        root_u, root_v = find(edge[0]), find(edge[1])
        if root_u == root_v:
            # Endpoints already connected: this edge would close a cycle.
            continue
        parent[root_u] = root_v
        mst.append(edge)
    return mst
In [10]:
# Run Kruskal on the weight-sorted edges and list the edges it kept.
mst = kruskal_MST(edges=sorted_edges)
for edge in mst:
    print(f'{chr(edge[0] + 65)} -- {edge[2]} --> {chr(edge[1] + 65)}')
B -- 2 --> E
A -- 3 --> C
E -- 3 --> F
B -- 4 --> C
E -- 4 --> G
C -- 5 --> D

Final minimum spanning tree

In [11]:
# Draw only the edges that made it into the MST, with their weights.
G5 = GraphVisualization(weighted=True, edge_list=mst)
G5.visualize()
No description has been provided for this image

Total cost of the MST

In [12]:
print(f"Total cost of MST = {sum(edge[2] for edge in mst)}")  # sum of the weight fields
Total cost of MST = 21

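Note: A graph can have more than one minimum spanning tree when several edges share the same weight. A different tie-breaking order among equal-weight edges may yield a different MST with the same total cost.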